import numpy as np

from typing import List, Optional
from oracles.saddle import ArrayPair, BaseSmoothSaddleOracle, OracleLinearComb
from methods.saddle import Logger
from .base import BaseSaddleMethod
from .constraints import ConstraintsL2


class CentralizedExtragradient(BaseSaddleMethod):
    """Centralized (projected) extragradient over the average of several oracles.

    Every node's oracle is queried at the same point and the gradients are
    averaged, so each iteration is a projected extragradient step on
    ``(1 / n) * sum_i oracle_i``::

        w     = P(z - stepsize * mean_i grad_i(z))
        z_new = P(z - stepsize * mean_i grad_i(w))

    where ``P`` is ``constraints.apply`` (skipped when ``constraints`` is None).
    After construction and after every step, ``self.gradient_mapping`` holds
    ``||z - P(z - eta * g(z))||^2 / eta^2`` at the current iterate, with
    ``g`` the averaged gradient and ``eta = gradient_mapping_stepsize``.
    """

    def __init__(
            self,
            oracles: List[BaseSmoothSaddleOracle],
            stepsize: float,
            z_0: ArrayPair,
            logger: Optional[Logger],
            constraints: Optional[ConstraintsL2] = None,
            gradient_mapping_stepsize: float = 1e-1,
    ):
        """
        Args:
            oracles: per-node oracles; the method works with their arithmetic
                mean (equal weights ``1 / len(oracles)``).
            stepsize: step size used for both extragradient half-steps.
            z_0: starting point (not projected onto the constraints).
            logger: forwarded to ``BaseSaddleMethod``.
            constraints: optional feasible set whose ``apply`` is called on
                every newly computed point (its return value is ignored, so it
                is presumably an in-place projection). ``None`` means
                unconstrained.
            gradient_mapping_stepsize: step size ``eta`` used only by the
                gradient-mapping diagnostic; it does not affect the iterates.
        """
        self._num_nodes = len(oracles)
        oracle_sum = OracleLinearComb(oracles, [1 / self._num_nodes] * self._num_nodes)
        super().__init__(oracle_sum, z_0, None, None, logger)
        self.oracle_list = oracles
        self.stepsize = stepsize
        self.constraints = constraints
        self.gradient_mapping_stepsize = gradient_mapping_stepsize
        self.z = z_0
        self.z_avg = z_0
        # Diagnostic at the starting point; needs oracle_list/constraints set above.
        self._update_gradient_mapping()

    def _project(self, z: ArrayPair) -> None:
        """Apply the constraints to ``z`` (in place); no-op when unconstrained."""
        if self.constraints is not None:
            self.constraints.apply(z)

    def _update_gradient_mapping(self) -> None:
        """Recompute ``self.gradient_mapping`` at ``self.z_avg``.

        Stores ``||z_avg - P(z_avg - eta * g(z_avg))||^2 / eta^2``, where ``g`` is
        the averaged gradient and ``eta = self.gradient_mapping_stepsize``.
        The oracle calls made here are not added to ``gradient_calls``.
        """
        eta = self.gradient_mapping_stepsize
        grad = ArrayPair.mean([oracle.grad(self.z_avg) for oracle in self.oracle_list])
        mapped = self.z_avg - eta * grad
        self._project(mapped)
        diff = self.z_avg - mapped
        self.gradient_mapping = diff.dot(diff) / (eta ** 2)

    def step(self):
        """Perform one projected extragradient iteration and refresh diagnostics."""
        # Extrapolation: w = P(z - stepsize * mean grad(z)).
        grad_z = ArrayPair.mean([oracle.grad(self.z) for oracle in self.oracle_list])
        w = self.z - self.stepsize * grad_z
        self._project(w)
        # Update from the current z using the gradient at the extrapolated point w.
        grad_w = ArrayPair.mean([oracle.grad(w) for oracle in self.oracle_list])
        self.z = self.z - self.stepsize * grad_w
        self._project(self.z)
        # Two gradient evaluations per node per iteration.
        self.gradient_calls += 2 * self._num_nodes
        # NOTE(review): presumably the communication volume of this round
        # (4 vectors per node) — confirm against BaseSaddleMethod.
        self.current_round_volume += 4 * self._num_nodes
        self.z_avg = self.z
        self._update_gradient_mapping()
        
